{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Copy of NSL CARLS example.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZwZNOAMZcxl3"
      },
      "source": [
        "##### Copyright 2021 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "nxbcnXODdE06"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-uOfi_T-BtGG"
      },
      "source": [
        "This colab shows how to use CARLS to train a model with regularization more efficiently.\n",
        "It first trains the model with graph regularization from Neural Structured Learning (NSL), then shows how to use CARLS to train. You can compare the time spent per step to see the difference."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "17yrp6L7dE-2"
      },
      "source": [
        "# Install necessary packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Igb_j8Pee9PK"
      },
      "outputs": [],
      "source": [
        "# Tools needed to fetch and register the Bazel APT signing key.\n",
        "# `-y` answers apt's confirmation prompt, which cannot be answered\n",
        "# interactively from a notebook cell.\n",
        "!apt install -y curl gnupg\n",
        "!curl https://bazel.build/bazel-release.pub.gpg | apt-key add -\n",
        "!echo \"deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8\" | tee /etc/apt/sources.list.d/bazel.list\n",
        "# Refresh the package index (now including the Bazel repository) and install\n",
        "# the Bazel version used to build the example.\n",
        "!apt update \u0026\u0026 apt install -y bazel-3.1.0\n",
        "!pip install neural-structured-learning\n",
        "!pip install protobuf==3.9.2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lMlp3ntZfNfs"
      },
      "outputs": [],
      "source": [
        "# Clear dir and download neural structured learning codebase.\n",
        "# The absolute path /content/nsl is the one the later cells use; spelling it\n",
        "# out here keeps this cell correct when it is re-run after the working\n",
        "# directory has been changed by a later `%cd`.\n",
        "!rm -rf /content/nsl\n",
        "!git clone https://github.com/tensorflow/neural-structured-learning.git /content/nsl"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QuVA8CxjdNop"
      },
      "source": [
        "# Prepare dataset and build the package with bazel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FqKn7b3AfOwL"
      },
      "outputs": [],
      "source": [
        "# Run the Cora preprocessing script from the cloned repository. The dataset\n",
        "# code below reads /tmp/cora/*.tfr, presumably written by this script\n",
        "# (TODO confirm against prep_data.sh).\n",
        "!cd /content/nsl \u0026\u0026 bash neural_structured_learning/examples/preprocess/cora/prep_data.sh\n",
        "\n",
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import os\n",
        "import threading\n",
        "from absl import app\n",
        "from absl import flags\n",
        "from absl import logging\n",
        "import attr\n",
        "\n",
        "import tensorflow as tf\n",
        "from google.protobuf import text_format\n",
        "\n",
        "# Naming pieces for neighbor features: keys are built as\n",
        "# NBR_FEATURE_PREFIX + <neighbor index> + '_<feature>' (or NBR_WEIGHT_SUFFIX).\n",
        "NBR_FEATURE_PREFIX = 'NL_nbr_'\n",
        "NBR_WEIGHT_SUFFIX = '_weight'\n",
        "\n",
        "\n",
        "@attr.s\n",
        "class HParams(object):\n",
        "  \"\"\"Hyperparameters shared by dataset parsing, model building and training.\"\"\"\n",
        "  ### dataset parameters\n",
        "  # Number of output classes of the classifier.\n",
        "  num_classes = attr.ib(default=7)\n",
        "  # Length of the fixed-size 'words' feature vector.\n",
        "  max_seq_length = attr.ib(default=1433)\n",
        "  ### NGM parameters\n",
        "  graph_regularization_multiplier = attr.ib(default=0.1)\n",
        "  # Number of neighbors parsed per example and used for regularization.\n",
        "  num_neighbors = attr.ib(default=1)\n",
        "  ### model architecture\n",
        "  # Sizes of the hidden dense layers. A factory gives every instance its own\n",
        "  # list; a plain list default would be a single object shared (and mutable)\n",
        "  # across all HParams instances.\n",
        "  num_fc_units = attr.ib(default=attr.Factory(lambda: [50, 50]))\n",
        "  ### training parameters\n",
        "  train_epochs = attr.ib(default=10)\n",
        "  batch_size = attr.ib(default=128)\n",
        "  dropout_rate = attr.ib(default=0.5)\n",
        "  ### eval parameters\n",
        "  eval_steps = attr.ib(default=None)  # Every test instance is evaluated.\n",
        "\n",
        "\n",
        "hparams = HParams()\n",
        "\n",
        "def make_dataset(file_path, training, include_nbr_features, hparams):\n",
        "  \"\"\"Builds a batched `tf.data.Dataset` from the TFRecord file at `file_path`.\n",
        "\n",
        "  Args:\n",
        "    file_path: Path of the TFRecord file to read.\n",
        "    training: If true, records are shuffled (buffer of 10000) before parsing.\n",
        "    include_nbr_features: If true, the per-neighbor 'words', weight and 'id'\n",
        "      features are parsed in addition to the example's own features.\n",
        "    hparams: Object providing `max_seq_length`, `num_neighbors` and\n",
        "      `batch_size`.\n",
        "\n",
        "  Returns:\n",
        "    A `tf.data.Dataset` yielding `(features, labels)` batches.\n",
        "  \"\"\"\n",
        "\n",
        "  def words_feature():\n",
        "    # Fixed-length int64 vector; examples without the feature get all zeros.\n",
        "    return tf.io.FixedLenFeature(\n",
        "        [hparams.max_seq_length],\n",
        "        tf.int64,\n",
        "        default_value=tf.constant(\n",
        "            0, dtype=tf.int64, shape=[hparams.max_seq_length]))\n",
        "\n",
        "  def parse_example(example_proto):\n",
        "    \"\"\"Parses one serialized `tf.train.Example`.\n",
        "\n",
        "    Args:\n",
        "      example_proto: A serialized `tf.train.Example`.\n",
        "\n",
        "    Returns:\n",
        "      A `(features, label)` pair: the dictionary of parsed features with the\n",
        "      'label' entry removed, and that label.\n",
        "    \"\"\"\n",
        "    # 'words' is a multi-hot, bag-of-words representation of the raw text.\n",
        "    # Every feature carries a default so examples missing it still parse.\n",
        "    spec = {\n",
        "        'id': tf.io.FixedLenFeature((), tf.string, default_value=''),\n",
        "        'words': words_feature(),\n",
        "        'label': tf.io.FixedLenFeature((), tf.int64, default_value=-1),\n",
        "    }\n",
        "\n",
        "    if include_nbr_features:\n",
        "      for nbr_index in range(hparams.num_neighbors):\n",
        "        # All keys of one neighbor share the prefix 'NL_nbr_<index>'.\n",
        "        nbr_prefix = '{}{}'.format(NBR_FEATURE_PREFIX, nbr_index)\n",
        "        spec[nbr_prefix + '_words'] = words_feature()\n",
        "        spec[nbr_prefix + NBR_WEIGHT_SUFFIX] = tf.io.FixedLenFeature(\n",
        "            [1], tf.float32, default_value=tf.constant([0.0]))\n",
        "        spec[nbr_prefix + '_id'] = tf.io.FixedLenFeature(\n",
        "            (), tf.string, default_value='')\n",
        "\n",
        "    parsed = tf.io.parse_single_example(example_proto, spec)\n",
        "    label = parsed.pop('label')\n",
        "    return parsed, label\n",
        "\n",
        "  # A sharded dataset would instead need something like:\n",
        "  # filenames = tf.data.Dataset.list_files(file_path, shuffle=True)\n",
        "  # dataset = filenames.interleave(load_dataset, cycle_length=1)\n",
        "  records = tf.data.TFRecordDataset([file_path])\n",
        "  if training:\n",
        "    records = records.shuffle(10000)\n",
        "  return records.map(parse_example).batch(hparams.batch_size)\n",
        "\n",
        "# Training data is shuffled and carries neighbor features; test data has\n",
        "# neither.\n",
        "train_dataset = make_dataset(\n",
        "    '/tmp/cora/train_merged_examples.tfr',\n",
        "    training=True,\n",
        "    include_nbr_features=True,\n",
        "    hparams=hparams)\n",
        "test_dataset = make_dataset(\n",
        "    '/tmp/cora/test_examples.tfr',\n",
        "    training=False,\n",
        "    include_nbr_features=False,\n",
        "    hparams=hparams)\n",
        "print(train_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1FcN-9Nne4MA"
      },
      "source": [
        "# Bazel build and run the example.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nZhj65xfe9m7"
      },
      "outputs": [],
      "source": [
        "# Build the CARLS graph-regularization example target from the cloned\n",
        "# repository, using the Bazel 3.1.0 binary installed earlier in this notebook.\n",
        "!cd /content/nsl \u0026\u0026 bazel-3.1.0 build research/carls/examples/graph_regularization:graph_keras_mlp_cora"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CdawT0YPws7V"
      },
      "outputs": [],
      "source": [
        "# Switch the working directory to the built target's runfiles tree; the\n",
        "# `research.carls` imports in the next cell presumably resolve relative to it\n",
        "# (TODO confirm).\n",
        "%cd /content/nsl/bazel-bin/research/carls/examples/graph_regularization/graph_keras_mlp_cora.runfiles/org_tensorflow_neural_structured_learning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lf-DY18Veiid"
      },
      "outputs": [],
      "source": [
        "import neural_structured_learning as nsl\n",
        "from research.carls import dynamic_embedding_config_pb2 as de_config_pb2\n",
        "from research.carls import dynamic_embedding_neighbor_cache as de_nb_cache\n",
        "from research.carls import graph_regularization\n",
        "from research.carls import kbs_server_helper_pybind as kbs_server_helper"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Kxi733fbf3TT"
      },
      "source": [
        "Define the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zQOnvAZPzi2d"
      },
      "outputs": [],
      "source": [
        "def build_model(hparams):\n",
        "  \"\"\"Creates the feed-forward classifier over the 'words' feature.\n",
        "\n",
        "  Args:\n",
        "    hparams: Object providing `max_seq_length`, `num_fc_units`,\n",
        "      `dropout_rate` and `num_classes`.\n",
        "\n",
        "  Returns:\n",
        "    An uncompiled `tf.keras.Sequential` model whose last layer is a softmax\n",
        "    over `num_classes` units.\n",
        "  \"\"\"\n",
        "  layers = [\n",
        "      tf.keras.layers.InputLayer(\n",
        "          input_shape=(hparams.max_seq_length,), name='words'),\n",
        "      # The input arrives in integer format; cast it to floating point before\n",
        "      # the dense layers.\n",
        "      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)),\n",
        "  ]\n",
        "  # One Dense(relu) + Dropout pair per configured hidden layer size.\n",
        "  for num_units in hparams.num_fc_units:\n",
        "    layers.append(tf.keras.layers.Dense(num_units, activation='relu'))\n",
        "    layers.append(tf.keras.layers.Dropout(hparams.dropout_rate))\n",
        "  layers.append(\n",
        "      tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))\n",
        "  return tf.keras.Sequential(layers)\n",
        "  \n",
        "def train_and_evaluate(model, path, train_dataset, test_dataset, hparams):\n",
        "  \"\"\"Compiles, trains, and evaluates a `Keras` model.\n",
        "\n",
        "  Args:\n",
        "    model: An instance of `tf.keras.Model`.\n",
        "    path: Directory in which a weights checkpoint is written after every\n",
        "      epoch (these checkpoints are used to update embeddings). If empty or\n",
        "      None, no checkpoints are written.\n",
        "    train_dataset: An instance of `tf.data.Dataset` representing training data.\n",
        "    test_dataset: An instance of `tf.data.Dataset` representing test data.\n",
        "    hparams: An instance of `HParams`.\n",
        "\n",
        "  Returns:\n",
        "    A dict mapping each name in `model.metrics_names` to its value on\n",
        "    `test_dataset`.\n",
        "  \"\"\"\n",
        "  model.compile(\n",
        "      optimizer='adam',\n",
        "      # The model ends in a softmax layer, so the loss receives probabilities.\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),\n",
        "      metrics=['accuracy'])\n",
        "  callbacks = []\n",
        "  if path:\n",
        "    # Write checkpoints under the caller-supplied directory rather than a\n",
        "    # hard-coded one, so the code reading the checkpoints and this writer\n",
        "    # always agree on the location.\n",
        "    callbacks = [tf.keras.callbacks.ModelCheckpoint(\n",
        "        os.path.join(path, 'weights-{epoch:02d}'),\n",
        "        save_weights_only=True,\n",
        "        save_freq='epoch'\n",
        "    )]\n",
        "  model.fit(\n",
        "      train_dataset,\n",
        "      epochs=hparams.train_epochs,\n",
        "      verbose=1,\n",
        "      callbacks=callbacks)\n",
        "  eval_results = dict(\n",
        "      zip(model.metrics_names,\n",
        "          model.evaluate(test_dataset, steps=hparams.eval_steps)))\n",
        "  # Hand the metrics back to the caller instead of discarding them.\n",
        "  return eval_results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dTpOljPA8DXm"
      },
      "outputs": [],
      "source": [
        "# Baseline run: no checkpoint path is passed, so no checkpoints are written.\n",
        "base_model = build_model(hparams)\n",
        "train_and_evaluate(\n",
        "    model=base_model,\n",
        "    path=None,\n",
        "    train_dataset=train_dataset,\n",
        "    test_dataset=test_dataset,\n",
        "    hparams=hparams)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N-9gisTX1TMx"
      },
      "source": [
        "Add configuration for knowledge bank that stores dynamic embeddings\n",
        "and start a server"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AbRXI1YaRVCc"
      },
      "outputs": [],
      "source": [
        "# Fill a fresh DynamicEmbeddingConfig message from its text-format\n",
        "# description. The embedding dimension is set to `hparams.num_classes`.\n",
        "de_config = de_config_pb2.DynamicEmbeddingConfig()\n",
        "text_format.Parse(\n",
        "  \"\"\"\n",
        "    embedding_dimension: %d\n",
        "    knowledge_bank_config {\n",
        "      initializer {\n",
        "        random_uniform_initializer {\n",
        "          low: -0.5\n",
        "          high: 0.5\n",
        "        }\n",
        "        use_deterministic_seed: true\n",
        "      }\n",
        "      extension {\n",
        "        [type.googleapis.com/carls.InProtoKnowledgeBankConfig] {}\n",
        "      }\n",
        "    }\n",
        "  \"\"\" % hparams.num_classes, de_config)\n",
        "\n",
        "# Start a knowledge bank server in this process.\n",
        "# NOTE(review): the meaning of the positional arguments (True, -1, 10) is not\n",
        "# visible here; confirm against KnowledgeBankServiceOptions.\n",
        "options = kbs_server_helper.KnowledgeBankServiceOptions(True, -1, 10)\n",
        "kbs_server = kbs_server_helper.KbsServerHelper(options)\n",
        "kbs_address = 'localhost:%d' % kbs_server.port()\n",
        "\n",
        "# Neighbor cache client talking to that server, with a 10000 ms timeout. 'id'\n",
        "# matches the 'id' feature parsed by the dataset.\n",
        "client = de_nb_cache.DynamicEmbeddingNeighborCache(\n",
        "    'id', de_config, kbs_address, timeout_ms=10000)\n",
        "print(kbs_address)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ogGH6vce1bHE"
      },
      "source": [
        "Background thread to load checkpoint and update embeddings"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pgqWFC69nZT-"
      },
      "outputs": [],
      "source": [
        "\n",
        "def update_embeddings(\n",
        "    stop_event, model_path, model, neighbor_cache_client, dataset):\n",
        "  \"\"\"Updates embeddings in knowledge bank server.\n",
        "\n",
        "  This runs in a background thread. Each time a checkpoint different from the\n",
        "  last one used appears under `model_path`, it is loaded into `model`,\n",
        "  inference is run over `dataset`, and the resulting outputs are written to\n",
        "  the knowledge bank server through `neighbor_cache_client`.\n",
        "\n",
        "  Args:\n",
        "    stop_event: `threading.Event` object used to stop updating the embeddings.\n",
        "    model_path: directory containing the saved checkpoints.\n",
        "    model: `GraphRegularizationWithCaching` object for inference.\n",
        "    neighbor_cache_client: `NeighborCacheClient` object to update embeddings.\n",
        "    dataset: dataset of `(features, label)` batches to run inference on;\n",
        "      `features` must contain an 'id' entry.\n",
        "  \"\"\"\n",
        "  logging.info('Start embedding updates')\n",
        "  last_used_ckpt_path = None\n",
        "\n",
        "  # Poll every 0.1s until stop_event is set.\n",
        "  while not stop_event.wait(0.1):\n",
        "    try:\n",
        "      latest_ckpt_path = tf.train.latest_checkpoint(model_path)\n",
        "      if latest_ckpt_path is None or latest_ckpt_path == last_used_ckpt_path:\n",
        "        # No checkpoint yet, or no new checkpoint since the last update.\n",
        "        continue\n",
        "      model.load_weights(latest_ckpt_path)\n",
        "    except tf.errors.NotFoundError:\n",
        "      # The checkpoint (or its directory) is not readable yet; retry on the\n",
        "      # next poll.\n",
        "      continue\n",
        "    # Record the checkpoint only after it loaded, so a failed load is retried.\n",
        "    last_used_ckpt_path = latest_ckpt_path\n",
        "\n",
        "    # Run inference on the dataset and update embeddings. The dataset is\n",
        "    # iterated explicitly because `Dataset.map` only describes a lazy\n",
        "    # transformation and would never execute the update per batch.\n",
        "    for features, _ in dataset:\n",
        "      if stop_event.wait(0.01):\n",
        "        break\n",
        "      neighbor_cache_client.update(\n",
        "          features['id'], model.base_model(features))\n",
        "\n",
        "  logging.info('Finished embedding updates')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9UQjtv801jg5"
      },
      "source": [
        "Build a model for training with graph regularization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N4JkjQUA0Co8"
      },
      "outputs": [],
      "source": [
        "# Graph regularization settings: L2 distance summed over the last axis, with\n",
        "# the neighbor count and multiplier taken from `hparams`.\n",
        "graph_reg_config = nsl.configs.make_graph_reg_config(\n",
        "    multiplier=hparams.graph_regularization_multiplier,\n",
        "    max_neighbors=hparams.num_neighbors,\n",
        "    sum_over_axis=-1,\n",
        "    distance_type=nsl.configs.DistanceType.L2)\n",
        "\n",
        "# NOTE(review): `base_model` was already trained by the baseline cell, so this\n",
        "# wrapper starts from those weights -- confirm that is intended.\n",
        "graph_reg_model = graph_regularization.GraphRegularizationWithCaching(\n",
        "    base_model, graph_reg_config, client)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5JlmSQaR1rJY"
      },
      "source": [
        "Train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i0hv_nxr1EqF"
      },
      "outputs": [],
      "source": [
        "# A separate inference model receives the checkpoints written during training\n",
        "# and is used by the background thread to refresh the embeddings.\n",
        "stop_event = threading.Event()\n",
        "base_inference_model = build_model(hparams)\n",
        "graph_inference_model = graph_regularization.GraphRegularizationWithCaching(\n",
        "    base_inference_model, graph_reg_config, client)\n",
        "update_thread = threading.Thread(\n",
        "    target=update_embeddings,\n",
        "    args=(stop_event, '/tmp/carls/output', \n",
        "          graph_inference_model, client, train_dataset))\n",
        "update_thread.daemon = True\n",
        "update_thread.start()\n",
        "\n",
        "try:\n",
        "  train_and_evaluate(graph_reg_model, '/tmp/carls/output',\n",
        "                     train_dataset, test_dataset, hparams)\n",
        "finally:\n",
        "  # Stop and wait for the updater before shutting the server down, so the\n",
        "  # thread never talks to a terminated knowledge bank server. `finally`\n",
        "  # releases the thread and the server even when training fails.\n",
        "  stop_event.set()\n",
        "  update_thread.join()\n",
        "  kbs_server.Terminate()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MjBmves1-gQi"
      },
      "outputs": [],
      "source": [
        "# Remove /tmp/carls, which holds the checkpoints written under\n",
        "# /tmp/carls/output during CARLS training.\n",
        "!rm -rf /tmp/carls\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "NSL CARLS example.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
